{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1-3. Text Data Modeling Workflow Example"
   ]
  },
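  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Prepare the Data\n",
    "\n",
    "The goal of the IMDB dataset is to predict the sentiment label of a movie review from the text of the review.\n",
    "\n",
    "The training set contains 20,000 movie reviews and the test set contains 5,000. In each set, half of the reviews are positive and half are negative.\n",
    "\n",
    "Preprocessing text data is fairly tedious. It involves Chinese word segmentation (not needed in this example), building a vocabulary, encoding words as integer ids, padding sequences, building a data pipeline, and so on."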
   ]
  },
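  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the vocabulary, encoding, and padding steps listed above more concrete, here is a minimal pure-Python sketch. The helper names `build_vocab` and `encode`, the toy sentences, and the convention of reserving id 0 for padding and id 1 for unknown words are assumptions made for this illustration; they are not the notebook's actual pipeline.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "PAD_ID = 0  # assumed id for padding positions\n",
    "OOV_ID = 1  # assumed id for words outside the vocabulary\n",
    "\n",
    "def build_vocab(texts, max_tokens):\n",
    "    # Count whitespace-separated, lowercased words over the whole corpus.\n",
    "    counts = Counter(word for text in texts for word in text.lower().split())\n",
    "    # Keep the most frequent words, leaving two ids for padding and unknown words.\n",
    "    kept = [word for word, _ in counts.most_common(max_tokens - 2)]\n",
    "    return {word: index + 2 for index, word in enumerate(kept)}\n",
    "\n",
    "def encode(text, vocab, max_len):\n",
    "    # Map each word to its id, falling back to OOV_ID for unseen words.\n",
    "    ids = [vocab.get(word, OOV_ID) for word in text.lower().split()]\n",
    "    # Truncate long sequences, then right-pad short ones to max_len.\n",
    "    ids = ids[:max_len]\n",
    "    return ids + [PAD_ID] * (max_len - len(ids))\n",
    "\n",
    "texts = ['a good movie', 'a good film', 'good fun']\n",
    "vocab = build_vocab(texts, max_tokens=4)\n",
    "print(vocab)\n",
    "print(encode('a good story', vocab, max_len=5))\n",
    "```\n",
    "\n",
    "With this toy corpus only the two most frequent words get their own ids, so the unseen word `story` falls back to the unknown-word id and the sequence is padded with zeros up to `max_len`."
   ]
  },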
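  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two common ways to preprocess text data in TensorFlow. The first uses the Tokenizer vocabulary-building tool from tf.keras.preprocessing together with tf.keras.utils.Sequence to build a text data generator pipeline.\n",
    "\n",
    "The second uses tf.data.Dataset together with the tf.keras.layers.experimental.preprocessing.TextVectorization preprocessing layer.\n",
    "\n",
    "The first approach is more involved; the following article shows an example of it.\n",
    "\n",
    "https://zhuanlan.zhihu.com/p/67697840\n",
    "\n",
    "The second approach is native to TensorFlow and is also somewhat simpler.\n",
    "\n",
    "We use the second approach here."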
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./data/电影评论.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import models, layers, preprocessing, optimizers, losses, metrics\n",
    "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
    "\n",
    "import re,string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[b'the', b'and', b'a', b'of', b'to', b'is', b'in', b'it', b'i', b'this', b'that', b'was', b'as', b'for', b'with', b'movie', b'but', b'film', b'on', b'not', b'you', b'his', b'are', b'have', b'be', b'he', b'one', b'its', b'at', b'all', b'by', b'an', b'they', b'from', b'who', b'so', b'like', b'her', b'just', b'or', b'about', b'has', b'if', b'out', b'some', b'there', b'what', b'good', b'more', b'when', b'very', b'she', b'even', b'my', b'no', b'would', b'up', b'time', b'only', b'which', b'story', b'really', b'their', b'were', b'had', b'see', b'can', b'me', b'than', b'we', b'much', b'well', b'get', b'been', b'will', b'into', b'people', b'also', b'other', b'do', b'bad', b'because', b'great', b'first', b'how', b'him', b'most', b'dont', b'made', b'then', b'them', b'films', b'movies', b'way', b'make', b'could', b'too', b'any', b'after', b'characters']\n"
     ]
    }
   ],
   "source": [
    "train_data_path = \"./data/imdb/train.csv\"\n",
    "test_data_path =  \"./data/imdb/test.csv\"\n",
    "\n",
    "MAX_WORDS = 10000  # Only keep the 10,000 most frequent words\n",
    "MAX_LEN = 200  # Truncate or pad each sample to 200 tokens\n",
    "BATCH_SIZE = 20 \n",
    "\n",
    "# Build the data pipeline\n",
    "def splite_line(line):\n",
    "    arr = tf.strings.split(line, \"\\t\")\n",
    "    label = tf.expand_dims(tf.cast(tf.strings.to_number(arr[0]), tf.int32), axis=0)\n",
    "    text = tf.expand_dims(arr[1], axis=0)\n",
    "    return (text, label)\n",
    "\n",
    "# The num_parallel_calls argument makes map() process elements in parallel to speed it up.\n",
    "# tf.data.experimental.AUTOTUNE lets tf.data choose the degree of parallelism automatically.\n",
    "# prefetch prepares a number of batches ahead of time, which improves GPU utilization.\n",
    "ds_train_raw = tf.data.TextLineDataset(filenames=[train_data_path]) \\\n",
    "            .map(splite_line, num_parallel_calls=tf.data.experimental.AUTOTUNE) \\\n",
    "            .shuffle(buffer_size=1000).batch(BATCH_SIZE) \\\n",
    "            .prefetch(tf.data.experimental.AUTOTUNE) \n",
    "\n",
    "ds_test_raw = tf.data.TextLineDataset(filenames=[test_data_path]) \\\n",
    "            .map(splite_line, num_parallel_calls=tf.data.experimental.AUTOTUNE) \\\n",
    "            .batch(BATCH_SIZE) \\\n",
    "            .prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "# Build the vocabulary\n",
    "def clean_text(text):\n",
    "    lowercase = tf.strings.lower(text)\n",
    "    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
    "    cleaned_punctuation = tf.strings.regex_replace(stripped_html,\n",
    "         '[%s]' % re.escape(string.punctuation),'')\n",
    "    return cleaned_punctuation\n",
    "\n",
    "vectorize_layer = TextVectorization(\n",
    "    standardize=clean_text,\n",
    "    split='whitespace',\n",
    "    max_tokens=MAX_WORDS-1, # one slot is reserved for the padding placeholder\n",
    "    output_mode='int',\n",
    "    output_sequence_length=MAX_LEN)\n",
    "\n",
    "ds_text = ds_train_raw.map(lambda text,label: text)\n",
    "vectorize_layer.adapt(ds_text)\n",
    "print(vectorize_layer.get_vocabulary()[0:100])\n",
    "\n",
    "\n",
    "# Encode the words as integer ids\n",
    "ds_train = ds_train_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
    "ds_test = ds_test_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)\n"
   ]
  },
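  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Define the Model\n",
    "\n",
    "The Keras API offers three ways to build a model: stacking layers in order with Sequential, building arbitrary architectures with the functional API, and subclassing the Model base class to define a custom model.\n",
    "\n",
    "Here we subclass the Model base class to build a custom model."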
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"cnn_model\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        multiple                  70000     \n",
      "_________________________________________________________________\n",
      "conv_1 (Conv1D)              multiple                  576       \n",
      "_________________________________________________________________\n",
      "max_pooling1d (MaxPooling1D) multiple                  0         \n",
      "_________________________________________________________________\n",
      "conv_2 (Conv1D)              multiple                  4224      \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            multiple                  0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                multiple                  12417     \n",
      "=================================================================\n",
      "Total params: 87,217\n",
      "Trainable params: 87,217\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# This demonstrates a custom (subclassed) model; in practice, prefer Sequential or the functional API.\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "class CnnModel(models.Model):\n",
    "    def __init__(self):\n",
    "        super(CnnModel, self).__init__()\n",
    "    \n",
    "    def build(self, input_shape):\n",
    "        self.embedding = layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN)\n",
    "        self.conv_1 = layers.Conv1D(16, kernel_size=5, name='conv_1', activation='relu')\n",
    "        self.pool = layers.MaxPool1D()\n",
    "        self.conv_2 = layers.Conv1D(128, kernel_size=2, name='conv_2', activation='relu')\n",
    "        self.flatten = layers.Flatten()\n",
    "        self.dense = layers.Dense(1, activation='sigmoid')\n",
    "        super(CnnModel, self).build(input_shape)\n",
    "        \n",
    "    def call(self, x):\n",
    "        x = self.embedding(x)\n",
    "        x = self.conv_1(x)\n",
    "        x = self.pool(x)\n",
    "        x = self.conv_2(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.dense(x)\n",
    "        return (x)\n",
    "\n",
    "model = CnnModel()\n",
    "model.build(input_shape=(None, MAX_LEN))\n",
    "model.summary()"
   ]
  },
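  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be reproduced by hand. The sketch below is plain Python arithmetic, assuming the Keras defaults used in `CnnModel` (stride 1 and `valid` padding for `Conv1D`, `pool_size=2` for `MaxPool1D`); the helper names `conv1d_out_len` and `conv1d_params` are made up for this illustration.\n",
    "\n",
    "```python\n",
    "def conv1d_out_len(length, kernel_size):\n",
    "    # Output length of a stride-1 convolution with 'valid' padding.\n",
    "    return length - kernel_size + 1\n",
    "\n",
    "def conv1d_params(in_channels, filters, kernel_size):\n",
    "    # One kernel of shape (kernel_size, in_channels) per filter, plus one bias per filter.\n",
    "    return kernel_size * in_channels * filters + filters\n",
    "\n",
    "max_words, embed_dim, max_len = 10000, 7, 200\n",
    "\n",
    "embedding_params = max_words * embed_dim         # one 7-dim vector per word id\n",
    "len_conv_1 = conv1d_out_len(max_len, 5)          # sequence length after conv_1\n",
    "conv_1_params = conv1d_params(embed_dim, 16, 5)\n",
    "len_pool = len_conv_1 // 2                       # MaxPool1D halves the length\n",
    "len_conv_2 = conv1d_out_len(len_pool, 2)         # sequence length after conv_2\n",
    "conv_2_params = conv1d_params(16, 128, 2)\n",
    "flat_units = len_conv_2 * 128                    # Flatten: length * channels\n",
    "dense_params = flat_units * 1 + 1                # weights plus one bias\n",
    "\n",
    "total_params = embedding_params + conv_1_params + conv_2_params + dense_params\n",
    "print(embedding_params, conv_1_params, conv_2_params, dense_params)\n",
    "print(total_params)\n",
    "```\n",
    "\n",
    "The four layer counts and their total should agree with the `Param #` column and the `Total params` line printed by `model.summary()`."
   ]
  },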
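  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Train the Model\n",
    "\n",
    "There are generally three ways to train a model: the built-in fit method, the built-in train_on_batch method, and a custom training loop. Here we train the model with a custom training loop."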
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print a separator line followed by the current time\n",
    "@tf.function\n",
    "def printbar():\n",
    "    ts = tf.timestamp()\n",
    "    today_ts = ts % (24*60*60)\n",
    "    # The +8 shifts the UTC hour to UTC+8\n",
    "    hour = tf.cast(today_ts//3600 + 8, tf.int32) % tf.constant(24)\n",
    "    minite = tf.cast((today_ts%3600)//60, tf.int32)\n",
    "    second = tf.cast(tf.floor(today_ts%60), tf.int32)\n",
    "    \n",
    "    def timeformat(m):\n",
    "        if tf.strings.length(tf.strings.format(\"{}\", m))==1:\n",
    "            return (tf.strings.format(\"0{}\", m))\n",
    "        else:\n",
    "            return (tf.strings.format(\"{}\", m))\n",
    "    timestring = tf.strings.join([timeformat(hour), timeformat(minite),\n",
    "                                 timeformat(second)], separator=':')\n",
    "    tf.print(\"==========\"*6,end = \"\")\n",
    "    tf.print(timestring)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================15:18:23\n",
      "Epoch=1,Loss:0.00793447159,Accuracy:0.9976,Valid Loss:2.22512126,Valid Accuracy:0.8504\n",
      "\n",
      "============================================================15:18:28\n",
      "Epoch=2,Loss:0.00334877172,Accuracy:0.99905,Valid Loss:2.23083282,Valid Accuracy:0.8506\n",
      "\n",
      "============================================================15:18:32\n",
      "Epoch=3,Loss:0.00199980522,Accuracy:0.9993,Valid Loss:2.25204253,Valid Accuracy:0.8518\n",
      "\n",
      "============================================================15:18:37\n",
      "Epoch=4,Loss:0.00476748263,Accuracy:0.99875,Valid Loss:2.21714354,Valid Accuracy:0.8474\n",
      "\n",
      "============================================================15:18:41\n",
      "Epoch=5,Loss:0.00365528697,Accuracy:0.99885,Valid Loss:2.24917,Valid Accuracy:0.8506\n",
      "\n",
      "============================================================15:18:46\n",
      "Epoch=6,Loss:0.00159265078,Accuracy:0.99955,Valid Loss:2.30116558,Valid Accuracy:0.8482\n",
      "\n"
     ]
    }
   ],
   "source": [
    "optimizer = optimizers.Nadam()\n",
    "loss_func = losses.BinaryCrossentropy()\n",
    "\n",
    "train_loss = metrics.Mean(name='train_loss')\n",
    "train_metric = metrics.BinaryAccuracy(name='train_accuracy')\n",
    "\n",
    "valid_loss = metrics.Mean(name='valid_loss')\n",
    "valid_metric = metrics.BinaryAccuracy(name='valid_accuracy')\n",
    "\n",
    "@tf.function\n",
    "def train_step(model, features, labels):\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(features, training=True)\n",
    "        loss = loss_func(labels, predictions)\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "    \n",
    "    train_loss.update_state(loss)\n",
    "    train_metric.update_state(labels, predictions)\n",
    "    \n",
    "@tf.function\n",
    "def valid_step(model, features, labels):\n",
    "    predictions = model(features, training=False)\n",
    "    batch_loss = loss_func(labels, predictions)\n",
    "    valid_loss.update_state(batch_loss)\n",
    "    valid_metric.update_state(labels, predictions)\n",
    "    \n",
    "def train_model(model, ds_train, ds_valid, epochs):\n",
    "    for epoch in tf.range(1, epochs+1):\n",
    "        for features, labels in ds_train:\n",
    "            train_step(model, features, labels)\n",
    "        \n",
    "        for features, labels in ds_valid:\n",
    "            valid_step(model, features, labels)\n",
    "            \n",
    "        # Adjust this log template to match the metrics in use\n",
    "        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}' \n",
    "        \n",
    "        if epoch%1 == 0:\n",
    "            printbar()\n",
    "            tf.print(tf.strings.format(logs,\n",
    "                    (epoch, train_loss.result(), train_metric.result(), \n",
    "                    valid_loss.result(), valid_metric.result())))\n",
    "            tf.print(\"\")\n",
    "        train_loss.reset_states()\n",
    "        train_metric.reset_states()\n",
    "        valid_loss.reset_states()\n",
    "        valid_metric.reset_states()\n",
    "\n",
    "train_model(model, ds_train, ds_test, epochs=6)"
   ]
  },
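  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop above follows an update, read, reset pattern for its metrics: each batch calls `update_state`, the epoch log calls `result`, and `reset_states` clears the accumulator before the next epoch. The sketch below mimics that pattern in plain Python. The `RunningMean` class and the batch loss values are invented for illustration, and the class ignores features of `metrics.Mean` such as sample weights.\n",
    "\n",
    "```python\n",
    "class RunningMean:\n",
    "    # Simplified stand-in for an averaging metric.\n",
    "    def __init__(self):\n",
    "        self.total = 0.0\n",
    "        self.count = 0\n",
    "\n",
    "    def update_state(self, value):\n",
    "        # Accumulate one batch-level value.\n",
    "        self.total += value\n",
    "        self.count += 1\n",
    "\n",
    "    def result(self):\n",
    "        # Mean of everything seen since the last reset.\n",
    "        return self.total / self.count if self.count else 0.0\n",
    "\n",
    "    def reset_states(self):\n",
    "        self.total = 0.0\n",
    "        self.count = 0\n",
    "\n",
    "# Made-up batch losses for two epochs.\n",
    "losses_per_epoch = [[0.5, 0.25, 0.75], [0.25, 0.25]]\n",
    "\n",
    "train_loss = RunningMean()\n",
    "epoch_means = []\n",
    "for batch_losses in losses_per_epoch:\n",
    "    for batch_loss in batch_losses:\n",
    "        train_loss.update_state(batch_loss)\n",
    "    epoch_means.append(train_loss.result())\n",
    "    train_loss.reset_states()  # without this, epoch 2 would also average epoch 1\n",
    "\n",
    "print(epoch_means)\n",
    "```\n",
    "\n",
    "Skipping the reset would make every reported value an average over all epochs so far, which is why the loop resets all four metrics at the end of each epoch."
   ]
  },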
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Evaluate the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A model trained with a custom training loop has not been compiled, so model.evaluate(ds_valid) cannot be called on it directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(model,ds_valid):\n",
    "    for features, labels in ds_valid:\n",
    "        valid_step(model,features,labels)\n",
    "    logs = 'Valid Loss:{},Valid Accuracy:{}' \n",
    "    tf.print(tf.strings.format(logs,(valid_loss.result(),valid_metric.result())))\n",
    "    \n",
    "    # Reset the validation metrics so the next evaluation starts from scratch\n",
    "    valid_loss.reset_states()\n",
    "    valid_metric.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid Loss:1.75198472,Valid Accuracy:0.858\n"
     ]
    }
   ],
   "source": [
    "evaluate_model(model,ds_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Use the Model\n",
    "\n",
    "Any of the following can be used:\n",
    "\n",
    "* model.predict(ds_test)\n",
    "* model(x_test)\n",
    "* model.call(x_test)\n",
    "* model.predict_on_batch(x_test)\n",
    "\n",
    "model.predict(ds_test) is the recommended first choice, because it works on both a Dataset and a Tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[7.6764333e-03],\n",
       "       [1.0000000e+00],\n",
       "       [1.0000000e+00],\n",
       "       ...,\n",
       "       [1.0000000e+00],\n",
       "       [4.7424770e-04],\n",
       "       [1.0000000e+00]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(ds_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[7.6764370e-03]\n",
      " [1.0000000e+00]\n",
      " [1.0000000e+00]\n",
      " [1.0597356e-34]\n",
      " [1.2959871e-13]\n",
      " [4.5892104e-08]\n",
      " [1.0962391e-21]\n",
      " [1.6840951e-12]\n",
      " [1.0000000e+00]\n",
      " [9.9999976e-01]\n",
      " [1.0000000e+00]\n",
      " [9.9990106e-01]\n",
      " [1.2506935e-14]\n",
      " [2.6104510e-01]\n",
      " [1.4092944e-13]\n",
      " [4.5176619e-01]\n",
      " [1.2707364e-09]\n",
      " [9.4303779e-08]\n",
      " [2.6665192e-15]\n",
      " [1.0000000e+00]], shape=(20, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "for x_test,_ in ds_test.take(1):\n",
    "    print(model(x_test))\n",
    "    # The following calls are equivalent:\n",
    "    #print(model.call(x_test))\n",
    "    #print(model.predict_on_batch(x_test))"
   ]
  },
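  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model outputs sigmoid probabilities rather than class labels. As a minimal sketch, the probabilities can be turned into 0/1 labels by thresholding at 0.5, which is also the default threshold of the `BinaryAccuracy` metric used during training. The helper name `to_labels` is an assumption for this example, and the sample values are copied from the outputs printed above.\n",
    "\n",
    "```python\n",
    "def to_labels(probabilities, threshold=0.5):\n",
    "    # Each row holds a single sigmoid output, mirroring the (batch, 1) shape above.\n",
    "    return [1 if row[0] > threshold else 0 for row in probabilities]\n",
    "\n",
    "# A few rows taken from the prediction output shown earlier.\n",
    "probs = [[7.6764333e-03], [1.0000000e+00], [2.6104510e-01],\n",
    "         [4.5176619e-01], [9.9990106e-01]]\n",
    "\n",
    "print(to_labels(probs))\n",
    "```\n",
    "\n",
    "On a NumPy array returned by `model.predict`, the same idea can be written as a vectorized comparison against the threshold."
   ]
  },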
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Save the Model\n",
    "\n",
    "Saving the model in TensorFlow's native format is recommended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ./data/tf_model_savedmodel/assets\n",
      "export saved model.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[7.6764333e-03],\n",
       "       [1.0000000e+00],\n",
       "       [1.0000000e+00],\n",
       "       ...,\n",
       "       [1.0000000e+00],\n",
       "       [4.7424770e-04],\n",
       "       [1.0000000e+00]], dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save('./data/tf_model_savedmodel', save_format=\"tf\")\n",
    "print('export saved model.')\n",
    "\n",
    "model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel/')\n",
    "model_loaded.predict(ds_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
